//  Copyright (c) Meta Platforms, Inc. and affiliates.
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
//

// Original NVIDIA copyright notice:

/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cuda_fp16.h>
#include <cmath>
#include <cstdint>
#include <type_traits>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Sum_ {
  static constexpr bool IS_SUM = true;
  static inline __device__ float apply(float x, float y) {
    return x + y;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Max_ {
  static constexpr bool IS_SUM = false;
  static inline __device__ float apply(float x, float y) {
    return x > y ? x : y;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compute exp(x - max) with the fast hardware exponential. Subtracting the
// row max keeps the argument non-positive, so the exponential cannot
// overflow.
inline __device__ float apply_exp_(float x, float max) {
  return __expf(x - max);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float apply_exp2_(float x, float max) {
  return exp2f(x - max);
  // With fast-math, this produces the same PTX instruction as the inline
  // assembly below:
  //   float diff = x - max;
  //   float res;
  //   asm("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
  //   return res;
}
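
// Illustrative identity (a sketch, not used directly): since
// exp(x - m) == exp2(x * log2(e) - m * log2(e)), callers below convert their
// operands to base 2 once and then use this helper:
//
//   float x2 = x * (float)M_LOG2E, m2 = m * (float)M_LOG2E;
//   float p = apply_exp2_(x2, m2); // matches __expf(x - m) up to rounding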

////////////////////////////////////////////////////////////////////////////////////////////////////

// The type each lane uses to read back the reduction buffer: each quad lane
// covers COLS / 4 columns, i.e. one float when there are 4 warps in N and a
// float2 when there are 8.
template <int COLS>
struct ReadType {};
template <>
struct ReadType<4> {
  using T = float;
};
template <>
struct ReadType<8> {
  using T = float2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
  // Helper class to distribute each warp's per-row MMA reduction results
  // across the quads of the warp.

  // The Mma tile.
  using Mma_tile = fmha::Hmma_tile<Cta_tile>;

  // The number of MMAs in M/N dimensions.
  static constexpr int MMAS_M = Mma_tile::MMAS_M;
  static constexpr int MMAS_N = Mma_tile::MMAS_N;

  static constexpr int WARPS_M = Cta_tile::WARPS_M;
  static constexpr int WARPS_N = Cta_tile::WARPS_N;

  static constexpr int ROWS = WARPS_M * MMAS_M * 16;
  static constexpr int COLS = WARPS_N;
  static_assert(COLS == 4 || COLS == 8);
  static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
  static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
  static constexpr int ELTS_PER_TILE = ROWS * COLS;

  static constexpr int THREADS_PER_GROUP =
      Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
  // TD [2022-05-02]: No longer true if head_dim != 64
  // static_assert(THREADS_PER_GROUP == 16); // DEBUG
  static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
  static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
  static_assert(LOOPS == 1);

  using read_t = typename ReadType<COLS>::T;

  __device__ inline Smem_tile_reduce(float* smem_, const int tidx) {
    int lane = tidx % 32;
    int warp = tidx / 32;

    int warp_m = warp % WARPS_M;
    int warp_n = warp / WARPS_M;

    qid_ = lane % 4;
    int qp = lane / 4;

    // Swizzle the column to avoid 2-way bank conflicts when we have 8 warps.
    // Reading the values back in a different order is fine because the
    // reduction ops are assumed commutative.
    const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
    smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
    smem_read_ = &reinterpret_cast<read_t*>(
        smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
    smem_read_row_ =
        &reinterpret_cast<read_t*>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
  }
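
  // Worked example of the swizzle (assuming WARPS_N == 8):
  // ROWS_PER_XOR_PATTERN == 4, so writer lanes with qp = 0..3 keep
  // col = warp_n while qp = 4..7 use col = warp_n ^ 1. Without the XOR, rows
  // qp and qp + 4 start exactly 32 banks apart, so both writers would hit
  // the same bank.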

  __device__ inline void store(float (&frag)[2 * MMAS_M]) {
    if (qid_ == 0) {
#pragma unroll
      for (int mi = 0; mi < MMAS_M; mi++) {
        int offset = mi * 16 * WARPS_N;
        smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
        smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
      }
    }
  }

  __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
    for (int mi = 0; mi < MMAS_M; mi++) {
      int offset = mi * 16 * 4;
      frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
      frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
    }
  }

  __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
#pragma unroll
    for (int mi = 0; mi < MMAS_M; mi++) {
      int offset = mi * 16 * 4;
      frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
    }
  }

  int qid_;
  float* smem_write_;
  read_t* smem_read_;
  read_t* smem_read_row_;
};
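
// Typical protocol for Smem_tile_reduce (see Softmax::reduce_ below): quad
// leaders store per-quad partials, the CTA synchronizes, then every thread
// reads WARPS_N partials back and finishes with a quad allreduce:
//
//   smem_red.store(frag);   // qid 0 writes one float per row pair
//   __syncthreads();
//   smem_red.load(tmp);     // order-insensitive read of the partials
//   quad_allreduce(frag, tmp, op);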

template <typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
  // The Mma tile.
  using Mma_tile = fmha::Hmma_tile<Cta_tile>;

  // The number of MMAs in M/N dimensions.
  static constexpr int MMAS_M = Mma_tile::MMAS_M;
  static constexpr int MMAS_N = Mma_tile::MMAS_N;

  // The number of groups of warps such that we have at most 4 warps writing
  // consecutive elements.
  static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
  // The number of elements that we are going to store per row.
  static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
  // The number of rows.
  static constexpr int ROWS = Cta_tile::M * GROUPS;
  // The total number of elements.
  static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;
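
  // Example (assuming WARPS_N == 8): GROUPS == 2 and ELEMENTS_PER_ROW == 4,
  // so the 8 partials of a logical row are split into two groups of 4 stored
  // in separate blocks of ROWS / GROUPS rows.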

  // Ctor.
  template <typename Params>
  inline __device__ Softmax_base(const Params& params, void* smem, int tidx)
      : // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
        smem_(reinterpret_cast<float*>(smem)),
        tidx_(tidx) {
    // Move to the 1st mask loaded by the thread + tidx:
    // packed_mask_ptr_ +=
    //     bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);

    // Extract the position in the warp.
    int warp = tidx / Cta_tile::THREADS_PER_WARP;
    int lane = tidx % Cta_tile::THREADS_PER_WARP;

    // Decompose the warp index into M and N.
    int warp_m = warp % Cta_tile::WARPS_M;
    int warp_n = warp / Cta_tile::WARPS_M;

    // Decompose the warp-n index into group/position-inside-the-group.
    int warp_g = warp_n / ELEMENTS_PER_ROW;
    int warp_i = warp_n % ELEMENTS_PER_ROW;

    // The location written by the threads.
    int write_row =
        warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
    int write_col = warp_i;

    // Assemble the write pointer.
    smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];

    // Assemble the read pointer.
    smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
  }

  template <bool zero = false, typename Mask>
  inline __device__ void apply_mask(const Mask& mask) {
#pragma unroll
    for (int mi = 0; mi < MMAS_M; ++mi) {
#pragma unroll
      for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
        for (int ni = 0; ni < MMAS_N; ++ni) {
#pragma unroll
          for (int jj = 0; jj < 4; ++jj) {
            if (!mask.is_valid(mi, ni, ii, jj)) {
              elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
            }
          }
        }
      }
    }
  }

  // Apply the exp to all the elements.
  template <bool max_in_base2 = false, bool elt_in_base2 = false>
  inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; ++mi) {
      // Instead of computing exp(x - max), we compute
      // exp2(x * log_2(e) - max * log_2(e)). This allows the compiler to use
      // the FFMA instruction instead of a separate FADD and FMUL.
      constexpr float kLog2e = M_LOG2E;
      const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
#pragma unroll
      for (int ni = 0; ni < MMAS_N * 4; ++ni) {
        // elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
        elt_[mi][ni] = apply_exp2_(
            elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e, max_base2);
      }
    }
  }
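
  // Because max_base2 is precomputed per row, each element then costs a
  // single FFMA feeding one ex2.approx: the compiler can contract the
  // expression to exp2f(fmaf(elt, kLog2e, -max_base2)).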

  // Apply the exp to all the elements, folding the softmax scale into the
  // base-2 conversion.
  template <bool scale_max = true>
  inline __device__ void scale_apply_exp(
      const float (&max)[MMAS_M * 2],
      const float scale_) {
    const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
    const float scale = scale_ * M_LOG2E;
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; ++mi) {
      // Instead of computing exp(x - max), we compute
      // exp2(x * log_2(e) - max * log_2(e)). This allows the compiler to use
      // the FFMA instruction instead of a separate FADD and FMUL.
      const float max_scaled = max[mi] * max_scale;
#pragma unroll
      for (int ni = 0; ni < MMAS_N * 4; ++ni) {
        elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
      }
    }
  }
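
  // Note: when scale_max is false, the caller is expected to pass a max that
  // already includes scale_, so only the log2(e) conversion is applied to it.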

  // Apply the exp to all the elements, using a per-column max.
  template <bool max_in_base2 = false>
  inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
#pragma unroll
    for (int ni = 0; ni < MMAS_N * 4; ++ni) {
      constexpr float kLog2e = M_LOG2E;
      const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
#pragma unroll
      for (int mi = 0; mi < MMAS_M * 2; ++mi) {
        elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
      }
    }
  }
  // inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
  //   constexpr float kLog2e = M_LOG2E;
  // #pragma unroll
  //   for (int ni = 0; ni < MMAS_N * 4; ++ni) {
  //     float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
  //     max_base2 = __shfl_sync(
  //         0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
  // #pragma unroll
  //     for (int mi = 0; mi < MMAS_M * 2; ++mi) {
  //       elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
  //     }
  //   }
  // }

  template <bool encode_dropout_in_sign_bit = false>
  inline __device__ void apply_dropout(Philox& ph, uint32_t p_dropout_in_uint) {
    // We encode the dropout pattern in the sign bit of the non-negative
    // softmax to distinguish from pre-existing zeros
    auto encode_dropout = [](bool keep, float val) {
      return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
    };
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; mi++) {
#pragma unroll
      for (int ni = 0; ni < MMAS_N; ni++) {
        uint4 tmp = ph();
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("ni = %d, ph  Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y,
        //     tmp.z, tmp.w);
        // }
        elt_[mi][4 * ni + 0] =
            encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
        elt_[mi][4 * ni + 1] =
            encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
        elt_[mi][4 * ni + 2] =
            encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
        elt_[mi][4 * ni + 3] =
            encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
      }
    }
  }
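
  // With encode_dropout_in_sign_bit, dropped elements become -val (and 0.0f
  // becomes -0.0f), so a dropped position is distinguishable from a genuine
  // zero by its sign bit alone.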

  template <bool encode_dropout_in_sign_bit = false>
  inline __device__ void apply_dropout(
      Philox& ph0,
      Philox& ph1,
      uint32_t p_dropout_in_uint) {
    // We encode the dropout pattern in the sign bit of the non-negative
    // softmax to distinguish from pre-existing zeros
    auto encode_dropout = [](bool keep, float val) {
      return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
    };
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; mi++) {
      static_assert(MMAS_N % 2 == 0);
#pragma unroll
      for (int ni = 0; ni < MMAS_N; ni += 2) {
        uint4 tmp = ph0();
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("ni = %d, ph0, Philox: %u, %u, %u, %u\n", ni, tmp.x,
        //     tmp.y, tmp.z, tmp.w);
        // }
        elt_[mi][4 * ni + 0] =
            encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
        elt_[mi][4 * ni + 1] =
            encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
        elt_[mi][4 * ni + 2] =
            encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
        elt_[mi][4 * ni + 3] =
            encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
        tmp = ph1();
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("ni = %d, ph1, Philox: %u, %u, %u, %u\n", ni + 1, tmp.x,
        //     tmp.y, tmp.z, tmp.w);
        // }
        elt_[mi][4 * (ni + 1) + 0] = encode_dropout(
            tmp.x <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 0]);
        elt_[mi][4 * (ni + 1) + 1] = encode_dropout(
            tmp.y <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 1]);
        elt_[mi][4 * (ni + 1) + 2] = encode_dropout(
            tmp.z <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 2]);
        elt_[mi][4 * (ni + 1) + 3] = encode_dropout(
            tmp.w <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 3]);
      }
    }
  }
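
  // This variant alternates between two Philox generators, ph0 for the even
  // N tiles and ph1 for the odd ones, so two independent RNG streams can be
  // kept in flight.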

  template <bool encode_dropout_in_sign_bit = false>
  inline __device__ void apply_dropout_16bits(
      Philox& ph,
      uint16_t p_dropout_in_uint16_t) {
    // We encode the dropout pattern in the sign bit of the non-negative
    // softmax to distinguish from pre-existing zeros
    auto encode_dropout = [](bool keep, float val) {
      return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
    };
#pragma unroll
    for (int mi = 0; mi < MMAS_M; mi++) {
#pragma unroll
      for (int ni = 0; ni < MMAS_N; ni++) {
        uint16_t tmp[8];
        fmha::uint4_to_ushort8(ph(), tmp);
#pragma unroll
        for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
          for (int jj = 0; jj < 4; ++jj) {
            elt_[mi * 2 + ii][4 * ni + jj] = encode_dropout(
                tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                elt_[mi * 2 + ii][4 * ni + jj]);
          }
        }
      }
    }
  }
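
  // The 16-bit variant unpacks one Philox uint4 into eight uint16_t draws
  // (fmha::uint4_to_ushort8), so one RNG call covers a full 2 x 4 sub-tile
  // where the 32-bit variant covers only one row of 4.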

  template <bool encode_dropout_in_sign_bit = false>
  inline __device__ void apply_dropout_16bits(
      Philox& ph0,
      Philox& ph1,
      uint16_t p_dropout_in_uint16_t) {
    // We encode the dropout pattern in the sign bit of the non-negative
    // softmax to distinguish from pre-existing zeros
    auto encode_dropout = [](bool keep, float val) {
      return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
    };
#pragma unroll
    for (int mi = 0; mi < MMAS_M; mi++) {
      static_assert(MMAS_N % 2 == 0);
#pragma unroll
      for (int ni = 0; ni < MMAS_N; ni += 2) {
        uint16_t tmp[8];
        fmha::uint4_to_ushort8(ph0(), tmp);
#pragma unroll
        for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
          for (int jj = 0; jj < 4; ++jj) {
            elt_[mi * 2 + ii][4 * ni + jj] = encode_dropout(
                tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                elt_[mi * 2 + ii][4 * ni + jj]);
          }
        }
        fmha::uint4_to_ushort8(ph1(), tmp);
#pragma unroll
        for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
          for (int jj = 0; jj < 4; ++jj) {
            elt_[mi * 2 + ii][4 * (ni + 1) + jj] = encode_dropout(
                tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
          }
        }
      }
    }
  }

  // Scale all the elements.
  inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
    // Precompute the reciprocal of the sum to normalize. Without
    // -use_fast_math, this makes a huge difference.
    float inv_sum[MMAS_M * 2];
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; ++mi) {
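      // sum != sum tests for NaN; rows with a zero or NaN sum are left
      // unscaled.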
      inv_sum[mi] =
          (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
    }

    // Update the values.
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; ++mi) {
#pragma unroll
      for (int ni = 0; ni < MMAS_N * 4; ++ni) {
        elt_[mi][ni] *= inv_sum[mi];
      }
    }
  }

  // Subtract dp_sum from all the elements.
  inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
#pragma unroll
    for (int mi = 0; mi < MMAS_M * 2; ++mi) {
#pragma unroll
      for (int ni = 0; ni < MMAS_N * 4; ++ni) {
        elt_[mi][ni] -= dp_sum[mi];
      }
    }
  }

  // The pointer to the mask.
  const char* packed_mask_ptr_;
  // Shared memory for the CTA-wide reduction.
  float *smem_, *smem_write_, *smem_read_;
  // The current thread index.
  int tidx_;
  // The elements.
  float elt_[MMAS_M * 2][MMAS_N * 4];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
  // The base class.
  using Base = Softmax_base<Cta_tile, Kernel_traits>;
  // The fragment.
  using Fragment_a = fmha::Fragment_a<fmha::Row>;

  static_assert(Fragment_a::NUM_REGS == 4);

  static constexpr int WARPS_M = Cta_tile::WARPS_M;
  static constexpr int WARPS_N = Cta_tile::WARPS_N;
  // The MMAs.
  static constexpr int MMAS_M = Base::MMAS_M;
  static constexpr int MMAS_N = Base::MMAS_N;

  // The accumulators.
  using Accumulator = fmha::Fragment_accumulator;
  using Accumulator_out = Fragment<uint16_t, 8>;
  static_assert(Accumulator_out::NUM_REGS == 4);

  static_assert(std::is_same<Accumulator::Data_type, float>::value);

  using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
  static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);

  // Ctor.
  template <typename Params>
  inline __device__ Softmax(const Params& params, void* smem, int tidx)
      : Base(params, smem, tidx),
        params_scale_bmm1_(params.scale_bmm1),
        smem_sum_(static_cast<float*>(smem), tidx),
        smem_max_(
            static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE,
            tidx) {}

  // Pack the data to a fragment for the next GEMM.
  template <int K, int M>
  inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
    for (int mi = 0; mi < M; ++mi) {
#pragma unroll
      for (int ki = 0; ki < K; ++ki) {
        // 1st row - 4 elements per row.
        float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
        float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
        float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
        float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];

        // 2nd row - 4 elements per row.
        float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
        float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
        float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
        float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];

        // Pack to 4 registers.
        dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
        dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
        dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
        dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
      }
    }
  }
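
  // The register order above matches the A-operand fragment layout expected
  // by the next HMMA: regs 0/1 hold the tile's two rows for the first column
  // pair and regs 2/3 the same rows for the second pair, mirroring unpack()
  // below.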

  // Unpack and scale FP32 fragments.
  inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
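    // params_scale_bmm1_ carries the float scale as raw uint32_t bits;
    // reinterpret it back to a float before use.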
    const float scalef =
        reinterpret_cast<const float&>(this->params_scale_bmm1_);

#pragma unroll
    for (int mi = 0; mi < MMAS_M; ++mi) {
#pragma unroll
      for (int ni = 0; ni < MMAS_N; ++ni) {
        // 1st row - 4 elements per row.
        this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
        this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
        this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
        this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
        // 2nd row - 4 elements per row.
        this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
        this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
        this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
        this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
      }
    }
  }

  // Unpack FP32 fragments without scaling.
  inline __device__ void unpack_noscale(
      const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
    for (int mi = 0; mi < MMAS_M; ++mi) {
#pragma unroll
      for (int ni = 0; ni < MMAS_N; ++ni) {
        // 1st row - 4 elements per row.
        this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
        this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
        this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
        this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
        // 2nd row - 4 elements per row.
        this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
        this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
        this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
        this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
      }
    }
  }

  template <bool zero_init = true, typename Operator>
  __device__ inline void thread_reduce_(
      float (&frag)[2 * MMAS_M],
      Operator& op) {
#pragma unroll
    for (int mi = 0; mi < 2 * MMAS_M; mi++) {
      frag[mi] =
          zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
#pragma unroll
      for (int ni = 1; ni < 4 * MMAS_N; ni++) {
        frag[mi] = op(frag[mi], this->elt_[mi][ni]);
      }
    }
  }

  template <bool zero_init = true, typename Operator>
  __device__ inline void reduce_(
      float (&frag)[2 * MMAS_M],
      Operator& op,
      Smem_tile_red& smem_red) {
    thread_reduce_<zero_init>(frag, op);
    quad_reduce(frag, frag, op);
    smem_red.store(frag);
    __syncthreads();
    typename Smem_tile_red::read_t tmp[2 * MMAS_M];
    smem_red.load(tmp);
    quad_allreduce(frag, tmp, op);
  }
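
  // reduce_ pipeline: per-thread reduction over the 4 * MMAS_N columns, quad
  // reduction via shuffles, a round trip through shared memory to cross
  // warps, then quad_allreduce so every thread ends with the final row
  // value.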

  template <bool zero_init = true>
  __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) {
    MaxOp<float> max;
    reduce_<zero_init>(frag, max, smem_max_);
  }

  __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) {
    SumOp<float> sum;
    reduce_(frag, sum, smem_sum_);
  }

  template <bool zero_init = true>
  __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]) {
    SumOp<float> sum;
    thread_reduce_<zero_init>(frag, sum);
    quad_reduce(frag, frag, sum);
    smem_sum_.store(frag);
  }

  template <int NROWS, typename Operator>
  __device__ inline void reduce_after_sync_(
      float (&frag)[NROWS][MMAS_M],
      const int (&rows)[NROWS],
      Operator& op,
      Smem_tile_red& smem_red) {
#pragma unroll
    for (int ii = 0; ii < NROWS; ii++) {
      typename Smem_tile_red::read_t tmp[MMAS_M];
      smem_red.load_row(tmp, rows[ii]);
      quad_allreduce(frag[ii], tmp, op);
    }
  }

  template <int NROWS>
  __device__ inline void reduce_sum_after_sync_(
      float (&frag)[NROWS][MMAS_M],
      const int (&rows)[NROWS]) {
    SumOp<float> sum;
    reduce_after_sync_(frag, rows, sum, smem_sum_);
  }

  template <int NROWS>
  __device__ inline void reduce_max_after_sync_(
      float (&frag)[NROWS][MMAS_M],
      const int (&rows)[NROWS]) {
    MaxOp<float> max;
    reduce_after_sync_(frag, rows, max, smem_max_);
  }

  // The scale for BMM1, carried as raw float bits.
  const uint32_t params_scale_bmm1_;
  // Shared memory tiles for the CTA-wide sum and max reductions, declared in
  // initialization order.
  Smem_tile_red smem_sum_;
  Smem_tile_red smem_max_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha
